{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "260629f9",
   "metadata": {},
   "source": [
    "# Rewrite-Retrieve-Read\n",
    "\n",
    "**Rewrite-Retrieve-Read** is a method proposed in the paper [Query Rewriting for Retrieval-Augmented Large Language Models](https://arxiv.org/pdf/2305.14283.pdf). The authors motivate it as follows:\n",
    "\n",
    "> Because the original query can not be always optimal to retrieve for the LLM, especially in the real world... we first prompt an LLM to rewrite the queries, then conduct retrieval-augmented reading\n",
    "\n",
    "This notebook shows how to implement it with the LangChain Expression Language (LCEL)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eda93712",
   "metadata": {},
   "source": [
    "## Baseline\n",
    "\n",
    "A baseline RAG pipeline (**Retrieve-and-read**) passes the user's question directly to the retriever and answers from whatever comes back:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1d2edbd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chat_models import ChatOpenAI\n",
    "from langchain.prompts import ChatPromptTemplate\n",
    "from langchain.utilities import DuckDuckGoSearchAPIWrapper\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.runnables import RunnablePassthrough"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "86a46aa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Answer the user's question based only on the following context:\n",
    "\n",
    "<context>\n",
    "{context}\n",
    "</context>\n",
    "\n",
    "Question: {question}\n",
    "\"\"\"\n",
    "prompt = ChatPromptTemplate.from_template(template)\n",
    "\n",
    "model = ChatOpenAI(temperature=0)\n",
    "\n",
    "search = DuckDuckGoSearchAPIWrapper()\n",
    "\n",
    "\n",
    "def retriever(query):\n",
    "    return search.run(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8566d48e",
   "metadata": {},
   "outputs": [],
   "source": [
    "chain = (\n",
    "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "5c57f9ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "simple_query = \"what is langchain?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "37c5f962",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"LangChain is a powerful and versatile Python library that enables developers and researchers to create, experiment with, and analyze language models and agents. It simplifies the development of language-based applications by providing a suite of features for artificial general intelligence. It can be used to build chatbots, perform document analysis and summarization, and streamline interaction with various large language model providers. LangChain's unique proposition is its ability to create logical links between one or more language models, known as Chains. It is an open-source library that offers a generic interface to foundation models and allows prompt management and integration with other components and tools.\""
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke(simple_query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23bdb9bd",
   "metadata": {},
   "source": [
    "This works for well-formed queries, but it can break down when the query contains distracting or irrelevant text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8df6a814",
   "metadata": {},
   "outputs": [],
   "source": [
    "distracted_query = \"man that sam bankman fried trial was crazy! what is langchain?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "16d7db64",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Based on the given context, there is no information provided about \"langchain.\"'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke(distracted_query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b4f8b93",
   "metadata": {},
   "source": [
    "This happens because the retriever searches on the entire \"distracted\" query, so the unrelated remark dominates the results:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3439d8dc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Business She\\'s the star witness against Sam Bankman-Fried. Her testimony was explosive Gary Wang, who co-founded both FTX and Alameda Research, said Bankman-Fried directed him to change a... The Verge, following the trial\\'s Oct. 4 kickoff: \"Is Sam Bankman-Fried\\'s Defense Even Trying to Win?\". CBS Moneywatch, from Thursday: \"Sam Bankman-Fried\\'s Lawyer Struggles to Poke ... Sam Bankman-Fried, FTX\\'s founder, responded with a single word: \"Oof.\". Less than a year later, Mr. Bankman-Fried, 31, is on trial in federal court in Manhattan, fighting criminal charges ... July 19, 2023. A U.S. judge on Wednesday overruled objections by Sam Bankman-Fried\\'s lawyers and allowed jurors in the FTX founder\\'s fraud trial to see a profane message he sent to a reporter days ... Sam Bankman-Fried, who was once hailed as a virtuoso in cryptocurrency trading, is on trial over the collapse of FTX, the financial exchange he founded. Bankman-Fried is accused of...'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "retriever(distracted_query)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7eb748ac",
   "metadata": {},
   "source": [
    "## Rewrite-Retrieve-Read Implementation\n",
    "\n",
    "The key component is a rewriter that turns the user's question into a better search query before retrieval."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "88ae702e",
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"Provide a better search query for \\\n",
    "web search engine to answer the given question, end \\\n",
    "the queries with ’**’. Question: \\\n",
    "{x} Answer:\"\"\"\n",
    "rewrite_prompt = ChatPromptTemplate.from_template(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "184e1bcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain import hub\n",
    "\n",
    "rewrite_prompt = hub.pull(\"langchain-ai/rewrite\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a4c23d40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Provide a better search query for web search engine to answer the given question, end the queries with ’**’.  Question {x} Answer:\n"
     ]
    }
   ],
   "source": [
    "print(rewrite_prompt.template)"
   ]
  },
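  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f55cd010",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parser to remove the `**` end marker that the prompt asks the model to append.\n",
    "# Note: str.strip treats its argument as a set of characters, so this removes\n",
    "# any leading or trailing `*`, not only the exact string `**`.\n",
    "\n",
    "\n",
    "def _parse(text):\n",
    "    return text.strip(\"**\")"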
   ]
  },
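  {
   "cell_type": "markdown",
   "id": "e4b19a7c",
   "metadata": {},
   "source": [
    "A model may also wrap the rewritten query in quotes or put whitespace around the `**` marker, which a plain `strip(\"**\")` leaves in place. As a minimal sketch, a more defensive parser might look like the following; `parse_rewritten_query` and `_TRIM_CHARS` are hypothetical names used for illustration, not part of LangChain:\n",
    "\n",
    "```python\n",
    "import string\n",
    "\n",
    "# Characters trimmed from both ends: whitespace, asterisks, and straight quotes.\n",
    "_TRIM_CHARS = string.whitespace + \"*'\\\"\"\n",
    "\n",
    "\n",
    "def parse_rewritten_query(text: str) -> str:\n",
    "    # str.strip treats its argument as a set of characters, not a substring.\n",
    "    return text.strip(_TRIM_CHARS)\n",
    "\n",
    "\n",
    "cleaned = parse_rewritten_query(' \"What is LangChain?\" ** ')\n",
    "print(cleaned)  # What is LangChain?\n",
    "```\n",
    "\n",
    "It could be used in the chain in place of `_parse` if the rewriter's raw output turns out to be noisy."
   ]
  },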
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c9c34bef",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewriter = rewrite_prompt | ChatOpenAI(temperature=0) | StrOutputParser() | _parse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "fb17fb3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'What is the definition and purpose of Langchain?'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rewriter.invoke({\"x\": distracted_query})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f83edb09",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewrite_retrieve_read_chain = (\n",
    "    {\n",
    "        \"context\": {\"x\": RunnablePassthrough()} | rewriter | retriever,\n",
    "        \"question\": RunnablePassthrough(),\n",
    "    }\n",
    "    | prompt\n",
    "    | model\n",
    "    | StrOutputParser()\n",
    ")"
   ]
  },
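  {
   "cell_type": "markdown",
   "id": "a1f4c9d2",
   "metadata": {},
   "source": [
    "The data flow of this chain can be sketched in plain Python. Every function below is a hypothetical stand-in (no LLM or search call is made); the point is only that the rewritten query feeds the retriever while the reader still receives the original question:\n",
    "\n",
    "```python\n",
    "def rewrite(question: str) -> str:\n",
    "    # Stand-in for the LLM rewriter: keep only the text after the last \"!\".\n",
    "    return question.split(\"!\")[-1].strip()\n",
    "\n",
    "\n",
    "def retrieve(query: str) -> str:\n",
    "    # Stand-in for the web search: look the query up in a tiny fake index.\n",
    "    fake_index = {\"what is langchain?\": \"LangChain is a framework for LLM apps.\"}\n",
    "    return fake_index.get(query, \"\")\n",
    "\n",
    "\n",
    "def read(context: str, question: str) -> str:\n",
    "    # Stand-in for `prompt | model | StrOutputParser()`: just format the inputs.\n",
    "    return f\"context={context!r} question={question!r}\"\n",
    "\n",
    "\n",
    "def rewrite_retrieve_read(question: str) -> str:\n",
    "    # The rewritten query is used only for retrieval;\n",
    "    # the reader is still given the original question.\n",
    "    context = retrieve(rewrite(question))\n",
    "    return read(context, question)\n",
    "\n",
    "\n",
    "answer = rewrite_retrieve_read(\"man that trial was crazy! what is langchain?\")\n",
    "print(answer)\n",
    "```\n",
    "\n",
    "Because the original question reaches the reader unchanged, the final answer can still comment on parts of it that the retrieved context does not cover."
   ]
  },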
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "43096322",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Based on the given context, LangChain is an open-source framework designed to simplify the creation of applications using large language models (LLMs). It enables LLM models to generate responses based on up-to-date online information and simplifies the organization of large volumes of data for easy access by LLMs. LangChain offers a standard interface for chains, integrations with other tools, and end-to-end chains for common applications. It is a robust library that streamlines interaction with various LLM providers. LangChain\\'s unique proposition is its ability to create logical links between one or more LLMs, known as Chains. It is an AI framework with features that simplify the development of language-based applications and offers a suite of features for artificial general intelligence. However, the context does not provide any information about the \"sam bankman fried trial\" mentioned in the question.'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rewrite_retrieve_read_chain.invoke(distracted_query)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
